import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from copy import deepcopy

import json
import numpy as np

from swift.llm import InferRequest, SamplingArguments
from swift.llm.infer.protocol import UsageInfo
from swift.utils import get_logger
from .base import Sampler
from .utils import get_reward, perform_infer

logger = get_logger()

NXT_PROMPT = """Continue.
"""

next_message = {
    "role": "user",
    "content": NXT_PROMPT,
}


class LanguageNode:

    def __init__(
        self,
        step: str = None,
        sep_token: str = None,
        parent: "LanguageNode" = None,
    ):
        self.parent = parent

        if sep_token:
            self.sep_token = sep_token
        else:
            self.sep_token = parent.sep_token

        if parent:
            self.path = parent.path[:] + [step]
            self.answer = parent.answer + step + self.sep_token
            self.depth = parent.depth + 1
        else:
            self.path = []
            self.answer = ""
            self.depth = 0

        self.active_children = []
        self.children = []
        self.visit_count = 0
        self.process_reward = 0.0
        self.outcome_reward = 0.0
        self.terminated = False
        self.correct = False

    def is_leaf(self):
        return len(self.children) == 0

    def is_root(self):
        return self.parent is None

    def visit(self):
        self.visit_count += 1

    def init_and_update_value(self, value):
        self.outcome_reward = (self.outcome_reward * self.visit_count + value) / (
            self.visit_count + 1
        )

    def add_child(self, child: "LanguageNode"):
        self.children.append(child)
        if not child.terminated:
            self.active_children.append(child)

    def collect(self):
        result = {
            "path": self.path,
            "depth": self.depth,
            "visit_count": self.visit_count,
            "process_reward": self.process_reward,
            "outcome_reward": self.outcome_reward,
            "terminated": str(self.terminated),
            "correct": str(self.correct),
            "children": [child.collect() for child in self.children],
        }
        return result

    def __lt__(self, other):
        return self.outcome_reward < other.outcome_reward


class MctsSampler(Sampler):

    def __init__(self, input_args: SamplingArguments):
        super().__init__(input_args)
        self.usage_info = UsageInfo(0, 0, 0)

    def _prepare_model_tokenizer(self):
        args = self.args
        self.infer_kwargs = {}
        if args.sampler_engine == "client":
            from swift.llm import InferClient

            api_key = args.api_key
            base_url = args.base_url
            self.infer_engine = [
                InferClient(base_url=base_url, api_key=api_key)
                for _ in range(args.num_return_sequences)
            ]
            self.infer_kwargs["model"] = args.model
        else:
            _Engine = self.get_infer_engine()
            self.infer_engine = _Engine(
                self.args.model,
                model_type=self.args.model_type,
                **self.args.engine_kwargs,
            )

    def get_infer_engine(self):
        if self.args.sampler_engine == "pt":
            from swift.llm import PtEngine

            _Engine = PtEngine
        elif self.args.sampler_engine == "vllm":
            from swift.llm import VllmEngine

            _Engine = VllmEngine
        elif self.args.sampler_engine == "lmdeploy":
            from swift.llm import LmdeployEngine

            _Engine = LmdeployEngine
        elif self.args.sampler_engine == "no":
            _Engine = None
        else:
            raise ValueError(f"Cannot find engine name: {self.args.sampler_engine}")
        return _Engine

    def _prepare_template(self) -> None:
        # Hack from super()
        self._prepare_request_configs()

    def _prepare_request_configs(self):
        _args = self.args
        request_config = _args.get_request_config()
        request_config.stop = _args.stop_words
        request_config.seed = _args.seed
        self.expand_request_configs = []
        self.rollout_request_configs = []
        for i in range(_args.num_return_sequences):
            expand_request_config = deepcopy(request_config)
            expand_request_config.n = 1
            expand_request_config.num_beams = expand_request_config.n
            expand_request_config.seed += i
            self.expand_request_configs.append(expand_request_config)
            rollout_request_config = deepcopy(request_config)
            rollout_request_config.max_tokens = 500
            rollout_request_config.temperature = 0.0
            rollout_request_config.n = 1
            self.rollout_request_configs.append(rollout_request_config)

    def update_usage_info(self, response):
        for key, value in self.usage_info.__dict__.items():
            update_value = getattr(response.usage, key, None) + value
            setattr(self.usage_info, key, update_value)

    def search_single(self, query, ground_truth):

        def _uct(uct_curr_node: LanguageNode):
            alpha = _args.process_reward_rate
            value = (
                alpha * uct_curr_node.process_reward
                + (1 - alpha) * uct_curr_node.outcome_reward
            )
            if uct_curr_node.is_root():
                return value

            exploitation_score = value
            exploration_score = _args.exploration_rate * np.sqrt(
                np.log(uct_curr_node.parent.visit_count + 1)
                / (uct_curr_node.visit_count + 1)
            )

            return exploration_score + exploitation_score

        def _select(select_curr_node: LanguageNode):
            while not select_curr_node.is_leaf():
                select_curr_node = max(
                    select_curr_node.active_children, key=lambda x: _uct(x)
                )
            return select_curr_node

        def _expand(expand_curr_node: LanguageNode):
            n = _args.num_return_sequences - len(expand_curr_node.children)
            if expand_curr_node.is_root():
                infer_requests = [
                    InferRequest(system_message + [prompt_message]) for _ in range(n)
                ]
            else:
                history_message = {
                    "role": "assistant",
                    "content": expand_curr_node.answer,
                }
                infer_request = InferRequest(
                    system_message + [prompt_message, history_message, next_message]
                )
                infer_requests = [infer_request for _ in range(n)]

            # e_time = time.time()
            # To perform the Expand operation in parallel,
            # there's no need to consider the order for now, since the Prompt is the same.
            expand_iter_index = 0
            while True:
                responses = perform_infer(
                    self.infer_engine,
                    infer_requests,
                    self.expand_request_configs,
                    **self.infer_kwargs,
                )
                if len(responses) > 0:
                    break
                if expand_iter_index == 5:
                    raise ValueError("Expand should not return any response")
                expand_iter_index += 1
            # logger.info(f"expand.expand time: {time.time() - e_time}")

            # To fetch Outcome Reward in parallel,
            # the Outcome-Reward obtained is returned in order, so they can be directly matched accordingly.
            orm_infer_requests = []
            unique_output = set()
            for response in responses:
                self.update_usage_info(response)
                output = (
                    response.choices[0]
                    .message.content.rstrip(sep_token + "\n")
                    .split(sep_token)[0]
                )
                if output in unique_output:
                    continue
                unique_output.add(output)
                orm_infer_requests.append(
                    InferRequest([{"role": "assistant", "content": output}])
                )
                child = LanguageNode(step=output, parent=expand_curr_node)
                if self.orm_model.check_terminate(child.answer)[0]:
                    child.terminated = True
                expand_curr_node.add_child(child)

            # e_time = time.time()
            orm_score, _orm_mask = get_reward(
                self.orm_model,
                orm_infer_requests,
                ground_truths=[ground_truth] * len(orm_infer_requests),
                threshold=0.0,
            )
            # logger.info(f"expand.orm time: {time.time() - e_time}")
            for child, score in zip(expand_curr_node.children, orm_score):
                if child.terminated:
                    child.init_and_update_value(score)
                    child.correct = score > 0.9
                    terminated_nodes.append(child)

            # e_time = time.time()
            if self.prm_model:
                prm_infer_requests = []
                for child in expand_curr_node.children:
                    prm_message = {"role": "assistant", "content": child.answer}
                    prm_infer_requests.append(
                        InferRequest([prompt_message, prm_message])
                    )
                prm_score, _prm_mask = get_reward(
                    self.prm_model,
                    prm_infer_requests,
                    ground_truths=[ground_truth] * len(prm_infer_requests),
                    threshold=0.0,
                )
                for child, score in zip(expand_curr_node.children, prm_score):
                    child.process_reward = score
            # logger.info(f"expand.prm time: {time.time() - e_time}")

        def _rollout(rollout_curr_node: LanguageNode):
            rollout_depth = 0
            rollout_nodes = {}
            for i in range(len(rollout_curr_node.active_children)):
                rollout_nodes[i] = {
                    "node": rollout_curr_node.active_children[i],
                    "history_messages": {
                        "role": "assistant",
                        "content": rollout_curr_node.active_children[i].answer,
                    },
                }
            active_rollout_nodes = list(rollout_nodes.keys())
            while len(active_rollout_nodes) > 0 and rollout_depth < _args.rollout_depth:
                # r_time = time.time()
                infer_requests = [
                    InferRequest(
                        system_message
                        + [
                            prompt_message,
                            rollout_nodes[index]["history_messages"],
                            next_message,
                        ]
                    )
                    for index in active_rollout_nodes
                ]
                # logger.info(f"rollout.prepare time: {time.time() - r_time}")
                # r_time = time.time()
                rollout_iter_index = 0
                while True:
                    responses = perform_infer(
                        self.infer_engine,
                        infer_requests,
                        self.rollout_request_configs,
                        **self.infer_kwargs,
                    )
                    if len(responses) > 0:
                        break
                    if rollout_iter_index == 5:
                        raise ValueError("Rollout should not return any response")
                    rollout_iter_index += 1
                # logger.info(f"rollout.infer time: {time.time() - r_time}")

                # r_time = time.time()
                orm_infer_requests = []
                end_paths = []
                for index, response in zip(active_rollout_nodes, responses):
                    self.update_usage_info(response)
                    output = (
                        response.choices[0]
                        .message.content.rstrip(sep_token + "\n")
                        .split(sep_token)[0]
                        + sep_token
                        + "\n"
                    )
                    rollout_nodes[index]["history_messages"]["content"] += output
                    end_paths.append(
                        rollout_nodes[index]["history_messages"]["content"]
                    )
                    orm_infer_requests.append(
                        InferRequest([rollout_nodes[index]["history_messages"]])
                    )
                # logger.info(f"rollout.orm_prepare time: {time.time() - r_time}")

                # r_time = time.time()
                orm_score, _orm_mask = get_reward(
                    self.orm_model,
                    orm_infer_requests,
                    ground_truths=[ground_truth] * len(infer_requests),
                    threshold=0.0,
                )
                # logger.info(f"rollout.get_orm time: {time.time() - r_time}")
                terminated_state = self.orm_model.check_terminate(end_paths)
                for index, score, terminated in zip(
                    active_rollout_nodes, orm_score, terminated_state
                ):
                    if terminated:
                        rollout_curr_node.active_children[index].init_and_update_value(
                            score
                        )
                        if score > 0.9:
                            rollout_correct_answers.append(
                                rollout_nodes[index]["history_messages"]["content"]
                            )
                        else:
                            rollout_incorrect_answers.append(
                                rollout_nodes[index]["history_messages"]["content"]
                            )
                        rollout_nodes.pop(index)
                active_rollout_nodes = list(rollout_nodes.keys())
                rollout_depth += 1

        def _back_propagate(back_curr_node: LanguageNode):
            while back_curr_node:
                if back_curr_node == curr_node:
                    best_child_value = max(
                        [child.outcome_reward for child in back_curr_node.children]
                    )
                    back_curr_node.init_and_update_value(best_child_value)
                    last_child_value = back_curr_node.outcome_reward
                else:
                    back_curr_node.init_and_update_value(last_child_value)
                    last_child_value = back_curr_node.outcome_reward
                back_curr_node.visit()
                if len(back_curr_node.active_children) == 0:
                    back_curr_node.terminated = True
                    if not back_curr_node.is_root():
                        back_curr_node.parent.active_children.remove(back_curr_node)
                back_curr_node = back_curr_node.parent

        _args = self.args
        system_message = [] + _args.system_message
        sep_token = _args.stop_words[0] + "\n"
        _root = LanguageNode(sep_token=sep_token)
        prompt_message = {
            "role": "user",
            "content": query,
        }

        rollout_correct_answers, rollout_incorrect_answers, terminated_nodes = (
            [],
            [],
            [],
        )
        iter_count = 0
        stop_reason = None
        while True:
            logger.info(f"iter_count: {iter_count}" + "." * 10)
            s_time = time.time()
            curr_node = _select(_root)
            logger.debug("select" + "=" * 10 + f"time: {time.time() - s_time}")
            s_time = time.time()
            _expand(curr_node)
            logger.debug("expand" + "=" * 10 + f"time: {time.time() - s_time}")
            if curr_node.depth > _args.rollout_start_depth:
                s_time = time.time()
                _rollout(curr_node)
                logger.debug("rollout" + "=" * 10 + f"time: {time.time() - s_time}")
            s_time = time.time()
            _back_propagate(curr_node)
            logger.debug("back propagate" + "=" * 10 + f"time: {time.time() - s_time}")
            if (
                len(rollout_correct_answers) + len(rollout_incorrect_answers)
                >= 2 * _args.num_return_sequences
            ):
                if 4 * len(rollout_incorrect_answers) < len(rollout_correct_answers):
                    stop_reason = "too easy"
                    break
                elif 4 * len(rollout_correct_answers) < len(rollout_incorrect_answers):
                    stop_reason = "too hard"
                    break
            if _root.terminated:
                stop_reason = "root terminated"
                break
            if len(terminated_nodes) >= _args.num_return_sequences:
                stop_reason = "enough nodes"
                break
            if iter_count >= _args.max_iterations:
                stop_reason = "max_iterations"
                break
            iter_count += 1
        logger.info(f"stop_reason: {stop_reason}")
        # logger.info(f"rollout_correct_answers: {rollout_correct_answers}")
        # logger.info(f"rollout_incorrect_answers: {rollout_incorrect_answers}")

        monte_carlo_tree = _root.collect()
        result = {
            "query": query,
            "ground_truth": ground_truth,
            "rollout_correct_answers": rollout_correct_answers,
            "rollout_incorrect_answers": rollout_incorrect_answers,
            "monte_carlo_tree": monte_carlo_tree,
        }
        result_json = json.dumps(result, ensure_ascii=False)
        logger.info(result_json)
        return result_json

    def do_sample(self, data):
        if not isinstance(data, list):
            data = [data]
        generated = []
        for item in data:
            logger.info(f"time: {time.ctime(time.time())}")
            try:
                messages = item["messages"][0]
                query = messages[0]["content"]
                ground_truth = messages[1]["content"]
                generated.append(self.search_single(query, ground_truth) + "\n")
            except Exception as e:
                logger.error(f"Error: {e}")
                logger.error(f"Traceback: {traceback.format_exc()}")
        return generated
